import networkx as nx
import numpy as np

def create_network(n, k=None, network_type='Gnp', directed=False, seed=None, cluster=None,
                   p_in=None, p_out=None, p=None):
    """
    Generate a social network as a networkx graph.

    INPUTS:
    - n:   number of individuals in the social system (int).
    - k:   average degree desired in the social network (int). Required for
           "Gnp", "regular" and "smallworld". For "regular", n * k must be even;
           for "smallworld", an even k is needed to obtain average degree k.
    - network_type: type of network to generate: "Gnp" (Erdos-Renyi), "regular",
                    "connected_caveman", "grid", "stochastic_block", or
                    "smallworld" (Watts-Strogatz) (str).
    - directed: if True, return a directed graph (bool). Types whose generator
                only builds undirected graphs are converted with `to_directed()`,
                so each undirected edge appears in both directions.
    - seed: random seed forwarded to the randomized generators ("Gnp", "regular",
            "stochastic_block", "smallworld").
    - cluster: number of clusters/cliques for "connected_caveman" and
               "stochastic_block" (positive int). For "connected_caveman" every
               clique has n // cluster members, so the graph has
               cluster * (n // cluster) nodes (fewer than n when n is not
               divisible by cluster). For "stochastic_block" the n % cluster
               leftover nodes are spread over the first blocks, giving exactly n nodes.
    - p_in: probability of edges within blocks for "stochastic_block".
    - p_out: probability of edges between blocks for "stochastic_block".
    - p: rewiring probability for "smallworld" (float).

    OUTPUT:
    - a networkx Graph, or a DiGraph when `directed` is True.

    RAISES:
    - ValueError if a parameter required by `network_type` is missing or
      invalid, or if `network_type` is not one of the supported types.
    """

    if network_type == "Gnp":
        if k is None:
            raise ValueError("Parameter 'k' must be provided for 'Gnp' network type.")
        # Edge probability giving expected degree k; a lone node has no possible partners.
        p_edge = k / (n - 1) if n > 1 else 0.0
        g = nx.erdos_renyi_graph(n, p_edge, seed=seed, directed=directed)

    elif network_type == "regular":
        if k is None:
            raise ValueError("Parameter 'k' must be provided for 'regular' network type.")
        g = nx.random_regular_graph(d=k, n=n, seed=seed)
        if directed:
            g = g.to_directed()

    elif network_type == "connected_caveman":
        if cluster is None or cluster < 1:
            raise ValueError("Parameter 'cluster' must be provided (as a positive int) for 'connected_caveman' network type.")
        # All cliques share one size, so leftover n % cluster nodes cannot be placed.
        clique_size = n // cluster
        g = nx.connected_caveman_graph(cluster, clique_size)
        if directed:
            g = g.to_directed()

    elif network_type == "grid":
        side_length = int(np.sqrt(n))
        if side_length ** 2 != n:
            raise ValueError("For a 2D lattice, n must be a perfect square.")
        g = nx.grid_2d_graph(side_length, side_length, periodic=False)
        # grid_2d_graph labels nodes by (row, col) tuples; relabel to 0..n-1.
        g = nx.convert_node_labels_to_integers(g)
        if directed:
            g = g.to_directed()

    elif network_type == "stochastic_block":
        if cluster is None or p_in is None or p_out is None:
            raise ValueError("Parameters 'cluster', 'p_in', and 'p_out' must be provided for 'stochastic_block' network type.")
        if cluster < 1:
            raise ValueError("Parameter 'cluster' must be a positive int for 'stochastic_block' network type.")
        base, extra = divmod(n, cluster)
        # Distribute the remainder so the block sizes sum to exactly n.
        sizes = [base + (1 if i < extra else 0) for i in range(cluster)]
        probs = [[p_in if i == j else p_out for j in range(cluster)] for i in range(cluster)]
        g = nx.stochastic_block_model(sizes, probs, seed=seed, directed=directed)

    elif network_type == "smallworld":
        if k is None or p is None:
            raise ValueError("Parameters 'k' and 'p' must be provided for 'smallworld' network type.")
        g = nx.watts_strogatz_graph(n=n, k=k, p=p, seed=seed)
        if directed:
            g = g.to_directed()

    else:
        raise ValueError("Invalid network type specified. Choose from 'Gnp', 'regular', 'connected_caveman', 'grid', 'stochastic_block', or 'smallworld'.")

    return g
